Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] remove flagslib #3733

Merged
merged 1 commit into from
Mar 13, 2024
Merged

[nnx] remove flagslib #3733

merged 1 commit into from
Mar 13, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 1, 2024

What does this PR do?

This PR removes the flagslib module is favor of a Module.set_attributes method that recursively sets the attributes of all Modules in the Module graph. Basically, flags are now just Module attributes plus a mechanism to recursively set them.

@cgarciae cgarciae force-pushed the nnx-remove-flags-context branch 3 times, most recently from 866f60f to 67202a7 Compare March 2, 2024 04:45
@cgarciae cgarciae changed the base branch from main to nnx-cleanup-graph-utils March 2, 2024 19:36
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Base automatically changed from nnx-cleanup-graph-utils to main March 8, 2024 02:27
@codecov-commenter
Copy link

codecov-commenter commented Mar 8, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 58.77%. Comparing base (ce8a3c7) to head (2206c40).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3733      +/-   ##
==========================================
+ Coverage   58.43%   58.77%   +0.34%     
==========================================
  Files         102      101       -1     
  Lines       12365    12409      +44     
==========================================
+ Hits         7225     7293      +68     
+ Misses       5140     5116      -24     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -0,0 +1,188 @@
# Copyright 2024 The Flax Authors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file was deleted after #3742. I think you need to rebase.

@@ -109,6 +109,12 @@ def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
vars(module)[field.name] = None
continue

if 'nnx_variable_constructor' not in field.metadata:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was deleted as well in #3742.

@@ -68,6 +68,17 @@ def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
with _mutable(obj), _initializing(obj):
obj.__init__(*args, **kwargs)

if dataclasses.is_dataclass(obj):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was deleted as well in #3742.

@@ -492,31 +516,31 @@ def __init__(self):

class TestModuleDataclass:
def test_basic(self):
@dataclasses.dataclass
@nnx.dataclass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was modified by #3742.

@@ -48,12 +47,12 @@ def __init__(self, y) -> None:
pytree.x = 4

def test_immutable_pytree_dataclass(self):
@dataclasses.dataclass(frozen=True)
@nnx.dataclass(frozen=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file was modified by #3742.

@@ -12,81 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import numpy as np
import warnings
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes don't look related to this PR. Maybe rebasing will fix this?

@@ -31,12 +31,9 @@
from jax import eval_shape, lax
from jax.core import ShapedArray

import opt_einsum

from flax.core import meta
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, rebase

self, *filters: filterlib.Filter, **attributes: tp.Any
) -> None:
"""Sets the attributes of nested Modules including the current Module.
If the attribute is not found in the Module, it is ignored.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you consider a flag that would raise an error if the attribute isn't found?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@cgarciae cgarciae force-pushed the nnx-remove-flags-context branch 3 times, most recently from 544091e to 41b1e2c Compare March 13, 2024 15:48
@cgarciae
Copy link
Collaborator Author

Cleaned the PR to remove the spurious changes.

@cgarciae cgarciae force-pushed the nnx-remove-flags-context branch 2 times, most recently from e5a2b4e to db0e96a Compare March 13, 2024 17:51
@copybara-service copybara-service bot merged commit 0280160 into main Mar 13, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-remove-flags-context branch March 13, 2024 21:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants